Conversation
Signed-off-by: mikail <mkhona@nvidia.com>
Greptile Summary

This PR fixes a critical bug in …
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant Opt as ObliqueSGD/ObliqueAdam
    participant WeightDecayMixin
    participant _compute_riemannian_grad
    Caller->>Opt: step()
    Opt->>Opt: update momentum buffer in-place<br/>torch.add(grad, buf, alpha=mom, out=buf)
    Opt->>_compute_riemannian_grad: (param, buf/norm_grad, dim)
    _compute_riemannian_grad-->>Opt: riem_grad (tangent-space projected)
    Opt->>WeightDecayMixin: _apply_weight_decay_inplace(param, riem_grad, lr, wd)
    Note over WeightDecayMixin: decoupled: param *= (1 - lr*wd)<br/>independent: param *= (1 - wd)<br/>l2: riem_grad += wd*param
    WeightDecayMixin-->>Opt: param or riem_grad updated in-place
    Opt->>Opt: param.add_(riem_grad, alpha=-lr)
    Opt->>Opt: normalize(param) — retract to manifold
```
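The flow in the diagram can be sketched as a minimal PyTorch snippet. This is a simplification for illustration only: `oblique_sgd_step` and its default arguments are hypothetical and do not match the repo's actual API; only the decoupled weight-decay mode is shown.

```python
import torch


def _compute_riemannian_grad(param: torch.Tensor, grad: torch.Tensor, dim: int) -> torch.Tensor:
    """Project grad onto the tangent space of the oblique manifold
    (unit-norm slices along `dim`): g - <g, p> * p per slice."""
    inner = (grad * param).sum(dim=dim, keepdim=True)
    return grad - inner * param


def oblique_sgd_step(param, grad, buf, lr=0.1, mom=0.9, wd=0.01, dim=1):
    # Update momentum buffer in-place: buf = grad + mom * buf
    torch.add(grad, buf, alpha=mom, out=buf)
    # Project the (momentum-smoothed) gradient to the tangent space
    riem_grad = _compute_riemannian_grad(param, buf, dim)
    # Decoupled weight decay: shrink the parameter directly
    param.mul_(1 - lr * wd)
    # Gradient step
    param.add_(riem_grad, alpha=-lr)
    # Retraction: renormalize back onto the manifold
    param.div_(param.norm(dim=dim, keepdim=True))
```

After the final `div_`, every slice along `dim` has unit norm again, which is the retraction step the last arrow in the diagram refers to.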
```python
torch.testing.assert_close(
    optimizer.state[param]["momentum_buffer"],
    expected_buffer,
    atol=0,
    rtol=0,
)
```
Strict zero-tolerance assertion may be fragile on second step
The second momentum-buffer assertion uses `atol=0, rtol=0` against `expected_buffer = second_grad + 0.8 * first_grad`. The optimizer computes `buf.mul_(0.8).add_(second_grad)`, while the expected value is computed as `second_grad + (0.8 * first_grad)` — two different Python/PyTorch expressions. Float32 addition is commutative (`a + b == b + a`), so the values are identical in this specific case, but the ordering of operations differs and could diverge on other hardware or precision modes.
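A quick self-contained check of that claim (illustrative only — `first_grad`/`second_grad` are stand-ins for the test's actual tensors):

```python
import torch

torch.manual_seed(0)
first_grad = torch.randn(4, dtype=torch.float32)
second_grad = torch.randn(4, dtype=torch.float32)

# Optimizer-side computation: buf starts as first_grad, then mul/add in-place
buf = first_grad.clone()
buf.mul_(0.8).add_(second_grad)          # computes (0.8 * first_grad) + second_grad

# Test-side expected value, written in the opposite order
expected = second_grad + 0.8 * first_grad

# IEEE-754 addition is commutative, so these match bit-for-bit here
assert torch.equal(buf, expected)
```

Note that while IEEE-754 addition is commutative, it is not associative: a third accumulated term grouped differently would break bit-exact equality, which is why the zero-tolerance assertion is fragile rather than outright wrong.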
Consider using a small tolerance to make the test more robust:
```diff
 torch.testing.assert_close(
     optimizer.state[param]["momentum_buffer"],
     expected_buffer,
-    atol=0,
-    rtol=0,
+    atol=1e-6,
+    rtol=1e-6,
 )
```
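Alternatively, omitting `atol`/`rtol` entirely lets `assert_close` fall back to its dtype-based defaults (for float32, `rtol=1.3e-6`, `atol=1e-5`), which absorb rounding noise while still catching real numerical bugs:

```python
import torch

actual = torch.tensor([1.0, 2.0], dtype=torch.float32)
expected = actual + 5e-7  # tiny rounding-level discrepancy

# With atol/rtol left as None, assert_close uses dtype-based defaults,
# so this passes despite the small difference.
torch.testing.assert_close(actual, expected)
```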
skyw left a comment
some minor things. otherwise LGTM
emerging_optimizers/riemannian_optimizers/normalized_optimizer.py
```python
riem_grad = _compute_riemannian_grad(param, buf, dim)

# Apply the weight update
param.mul_(1 - lr * wd)
```
Q: would addmm be enough for this?
this step is done this way for literally every other optimizer in this repo
Remind me the reason? I vaguely remember one of the weight-decay types can't be done in a single addmm.
I also forgot, but we made the opt_mixin for this
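For what it's worth, the likely reason: only the `l2` mode folds the decay into the gradient (and so could fuse with the update into a single addmm-style call), while `decoupled` and `independent` scale the parameter itself, which is a separate in-place multiply. A hypothetical sketch of such a mixin — names mirror the sequence diagram, not necessarily the repo's code:

```python
import torch


def apply_weight_decay_inplace(param, riem_grad, lr, wd, mode="decoupled"):
    """Hypothetical stand-in for WeightDecayMixin._apply_weight_decay_inplace."""
    if mode == "decoupled":
        param.mul_(1 - lr * wd)          # shrink weights; gradient untouched
    elif mode == "independent":
        param.mul_(1 - wd)               # shrink independent of the learning rate
    elif mode == "l2":
        riem_grad.add_(param, alpha=wd)  # classic L2: fold wd*param into the gradient
    else:
        raise ValueError(f"unknown weight decay mode: {mode!r}")
```

Since two of the three modes mutate `param` rather than `riem_grad`, a single fused add on the gradient can't express all of them, which would explain the shared mixin.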
```python
"""Test that ObliqueSGD persists momentum state across optimization steps."""
param = torch.tensor(
    [[1.0, 0.0], [0.0, 1.0]],
    dtype=torch.float32,
```
nit: `dtype=torch.float32` is not necessary; the default dtype is almost never changed.
Signed-off-by: mikail <mkhona@nvidia.com>
/ok to test ec6b492
Signed-off-by: mikail <mkhona@nvidia.com>
/ok to test 7091f0c
Address issue #136